import numpy as np
import continuous_scores
import discrete_scores
from collections import deque

'''
Implements local search (one-flip neighbourhood search) for finding improving parent-set patterns, i.e. patterns with negative reduced cost
'''

# def local_search(d, i, K, init, duals, C_set, regu_Lambda):
#     n = d.n
#     ndata = d.ndata
#     lambda_c = duals[n:]
#     best_rc = 0
#     has_good_pattern = False # with negative rc
#     x = np.delete(init, i)

#     if d.data_type=='C':
#         g = continuous_scores.g
#         h = continuous_scores.h
#     else:
#         g = discrete_scores.g
#         h = discrete_scores.h

#     queue = deque()
#     queue.append((x.copy(), best_rc))
#     iter = 0
#     while queue and iter<500:
#         current_x, current_best_rc = queue.popleft()
#         for j in range(len(current_x)):
#             x_ = current_x.copy()
#             x_[j] = 1 - x_[j]  # change one coordinate
#             cost_x = g(d, x_, i, C_set, lambda_c, regu_Lambda)[0] - h(d, x_, i)[0]
#             if d.data_type=='C':
#                 rc = cost_x * ndata / 2 + (np.log(2 * np.pi) + 1) * ndata / 2 - duals[i]
#             else:
#                 rc = cost_x*ndata - duals[i]

#             if rc < current_best_rc:
#                 current_best_rc = rc  # update the best
#                 good_pattern = x_
#                 has_good_pattern = True
#         if has_good_pattern:
#             queue.append((good_pattern, current_best_rc))
#         iter += 1

#     if has_good_pattern:
#         good_pattern = np.insert(good_pattern, i, 0)
#         K[i] = np.vstack((K[i], good_pattern))

#     return K

def local_search(d, i, K, init, duals, C_set, regu_Lambda):
    '''Search the one-flip neighbourhood of ``init`` for an improving parent set of node ``i``.

    Every pattern that differs from ``init`` in exactly one entry (entry ``i``
    itself is excluded) is scored. The neighbour with the most negative reduced
    cost is appended as a new row of ``K[i]``. If no neighbour has a strictly
    negative reduced cost, ``K`` is returned without modification.

    Parameters
    ----------
    d : data object exposing ``n``, ``ndata`` and ``data_type``. A
        ``data_type`` of ``'C'`` selects ``continuous_scores``; any other
        value selects ``discrete_scores``.
    i : index of the node whose parent-set patterns are searched.
    K : indexable container of pattern arrays. When an improving pattern is
        found, ``K[i]`` is replaced by ``np.vstack((K[i], pattern))``, so the
        caller's ``K`` is mutated. Assumes ``K[i]`` has rows of length ``n``
        (TODO confirm).
    init : pattern of length ``n`` to start from. Entries are flipped as
        ``1 - v``, so it presumably holds 0/1 values (verify against callers).
    duals : dual values. ``duals[i]`` is subtracted in the reduced cost, and
        ``duals[n:]`` is passed to ``g`` as ``lambda_c``.
    C_set, regu_Lambda : passed through unchanged to ``g``.

    Returns
    -------
    K, with at most one new row appended to ``K[i]``.
    '''
    n = d.n
    ndata = d.ndata
    lambda_c = duals[n:]
    if d.data_type == 'C':
        g = continuous_scores.g
        h = continuous_scores.h
    else:
        g = discrete_scores.g
        h = discrete_scores.h

    # Only strictly negative reduced costs count as improving.
    best_rc = 0
    good_pattern = None
    # Drop entry i (the node itself) so only candidate parents are flipped.
    x = np.delete(init, i)

    for j in range(len(x)):
        x_ = x.copy()
        x_[j] = 1 - x[j]  # change exactly one coordinate
        cost_x = g(d, x_, i, C_set, lambda_c, regu_Lambda)[0] - h(d, x_, i)[0]
        if d.data_type == 'C':
            rc = cost_x * ndata / 2 + (np.log(2 * np.pi) + 1) * ndata / 2 - duals[i]
        else:
            rc = cost_x * ndata - duals[i]
        if rc < best_rc:
            # Keep the lowest reduced cost seen so far, not merely the last
            # negative one, so the most improving neighbour is added.
            best_rc = rc
            good_pattern = x_

    if good_pattern is not None:
        # Re-insert a 0 at position i to restore a full-length pattern.
        K[i] = np.vstack((K[i], np.insert(good_pattern, i, 0)))
    return K